// cmd/protoc-gen-custom/main.go
package main

import (
	"bytes"
	"fmt"
	"html/template"
	"os"
	"strings"

	"github.com/gzydong/go-chat/api/pb/common"
	"google.golang.org/genproto/googleapis/api/annotations"
	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/pluginpb"

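	// proto provides HasExtension/GetExtension, used below to read the
	// google.api.http and custom option extensions from descriptor options.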
	"google.golang.org/protobuf/proto"
)

// version is the plugin release printed by --version and stamped into the
// header of every generated file.
const (
	version = "1.0.0"
)

// HTTPRule describes how a single RPC is exposed over HTTP.
type HTTPRule struct {
	// Method is the HTTP verb in lower case, e.g. "get" or "post".
	Method string

	// Path is the route, taken from the google.api.http annotation or
	// derived from the service and method names when no annotation exists.
	Path string

	// Body mirrors the annotation's body selector ("*" for the default
	// rule). It is not consumed by the generated handler code.
	Body string
}

// main is the plugin entry point. With the single argument "--version" it
// prints the plugin version and returns; otherwise it hands control to
// protogen.Options.Run and calls generateFile for every file marked for
// generation.
func main() {
	if args := os.Args; len(args) == 2 && args[1] == "--version" {
		_, _ = fmt.Fprintf(os.Stdout, "protoc-gen-custom %v\n", version)
		return
	}

	generate := func(plugin *protogen.Plugin) error {
		// Advertise proto3 `optional` support so protoc accepts such files.
		plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
		for _, f := range plugin.Files {
			if f.Generate {
				generateFile(plugin, f)
			}
		}
		return nil
	}

	var opts protogen.Options
	opts.Run(generate)
}

// generateFile writes <GeneratedFilenamePrefix>.bff.go for a proto file that
// declares at least one service. Files without services produce no output.
// The file contains a generated-code header, the package clause, the import
// block and, for each service, its handler interface and registration function.
func generateFile(gen *protogen.Plugin, file *protogen.File) {
	services := file.Services
	if len(services) == 0 {
		return
	}

	out := gen.NewGeneratedFile(file.GeneratedFilenamePrefix+".bff.go", file.GoImportPath)

	// Header: generated-code marker, plugin version, package clause.
	out.P("// Code generated by protoc-gen-bff. DO NOT EDIT.")
	out.P("// Version: ", version)
	out.P()
	out.P("package ", file.GoPackageName)
	out.P()

	generateImports(out, file)

	for i := range services {
		generateService(out, services[i])
	}
}

// generateImports writes the import block of the generated file.
//
// gin is always imported because every service gets a
// Register<Service>Handler function that takes a gin.IRoutes. The "context"
// package is referenced only by the handler interface methods (one per RPC).
// It is therefore imported only when at least one service in the file
// declares an RPC. Importing it unconditionally made the generated file fail
// to compile ("imported and not used") for services without methods.
func generateImports(g *protogen.GeneratedFile, file *protogen.File) {
	needContext := false
	for _, service := range file.Services {
		if len(service.Methods) > 0 {
			needContext = true
			break
		}
	}

	g.P("import (")
	if needContext {
		g.P(`    "context"`)
		g.P()
	}
	g.P(`    "github.com/gin-gonic/gin"`)
	g.P(")")
	g.P()
}

// generateService emits two declarations for one protobuf service:
//   - I<Service>Handler, an interface with one method per RPC that takes a
//     context.Context and the request message and returns the response
//     message and an error;
//   - Register<Service>Handler, a function that panics if handler is nil and
//     otherwise registers one gin route per RPC (rendered by generateHandler).
func generateService(g *protogen.GeneratedFile, service *protogen.Service) {
	svc := service.GoName
	methods := service.Methods

	// Handler interface: one method per RPC.
	g.P("// ", "I"+svc, "Handler BFF 接口")
	g.P("type ", "I"+svc, "Handler interface {")
	for i := range methods {
		m := methods[i]
		req := g.QualifiedGoIdent(m.Input.GoIdent)
		resp := g.QualifiedGoIdent(m.Output.GoIdent)
		g.P("    ", m.GoName, "(ctx context.Context, req *", req, ") (*", resp, ", error)")
	}
	g.P("}")
	g.P()

	// Registration function: signature, nil-handler guard, then one route per RPC.
	g.P("// Register", svc, "Handler 注册服务路由处理器")
	g.P("func Register", svc, "Handler(r gin.IRoutes, s interface{")
	for _, line := range []string{
		"    ShouldProto(c *gin.Context, in any) error",
		"    ErrorResponse(c *gin.Context, err error)",
		"    SuccessResponse(c *gin.Context, data any)",
	} {
		g.P(line)
	}
	g.P("}, handler ", "I"+svc, "Handler) {")
	g.P("    if handler == nil {")
	g.P("        panic(\"handler is nil\")")
	g.P("    }")
	g.P()

	for i := range methods {
		m := methods[i]
		// Arguments evaluate left to right: HTTP rule first, then the request type.
		if code := generateHandler(getMethodHTTPRule(m), g.QualifiedGoIdent(m.Input.GoIdent), m.GoName); code != nil {
			_, _ = g.Write(code)
		}
		g.P()
	}

	g.P("}")
	g.P()
}

// handler is the Go source emitted for one RPC route inside the generated
// Register<Service>Handler function. Placeholders:
//
//	.Method   - gin.IRoutes registration method: GET, POST, ... or Handle
//	.Route    - quoted route arguments, e.g. "/user/login" or "SEARCH", "/x"
//	.Request  - qualified Go type of the request message
//	.FuncName - handler interface method to invoke
var handler = `
	r.{{.Method}}({{.Route}}, func(c *gin.Context) {
		var in {{.Request}}
		if err := s.ShouldProto(c, &in); err != nil {
			s.ErrorResponse(c, err)
			return
		}

		data, err := handler.{{.FuncName}}(c.Request.Context(), &in)
		if err != nil {
			s.ErrorResponse(c, err)
			return
		}

		s.SuccessResponse(c, data)
	})
`

// handlerTpl is parsed once. A parse failure can only come from a bug in the
// constant above, so it panics at start-up instead of silently dropping routes.
var handlerTpl = template.Must(template.New("tpl").Parse(handler))

// generateHandler renders the gin route registration for one RPC.
//
// rule supplies the HTTP method and path. inputType is the qualified Go type
// of the request message. funcName is the handler interface method to call.
//
// Method handling:
//   - The route is registered with rule.Method. Previously every route was
//     registered as POST regardless of the google.api.http annotation.
//   - Verbs that gin.IRoutes exposes directly (GET, POST, PUT, DELETE, PATCH,
//     HEAD, OPTIONS) use the shortcut method.
//   - Any other verb (e.g. from a custom pattern) is registered via r.Handle.
//   - An empty method defaults to POST.
//
// NOTE: this file imports html/template, whose contextual auto-escaping
// targets HTML (it would turn "+" into "&#43;" and `"` into "&#34;", corrupting
// the generated Go source). Every value is therefore wrapped in template.HTML,
// so it is emitted verbatim. The path is emitted as a Go-quoted string literal,
// so quotes and backslashes in it stay valid Go. Switching the file's import
// to text/template would make the wrapping unnecessary.
func generateHandler(rule *HTTPRule, inputType string, funcName string) []byte {
	verb := strings.ToUpper(rule.Method)
	if verb == "" {
		verb = "POST"
	}

	method := verb
	route := fmt.Sprintf("%q", rule.Path)
	switch verb {
	case "GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS":
		// gin.IRoutes has a dedicated registration method for these verbs.
	default:
		method = "Handle"
		route = fmt.Sprintf("%q, %s", verb, route)
	}

	var buf bytes.Buffer
	if err := handlerTpl.Execute(&buf, map[string]any{
		"Method":   template.HTML(method),
		"Route":    template.HTML(route),
		"Request":  template.HTML(inputType),
		"FuncName": template.HTML(funcName),
	}); err != nil {
		// Rendering a fixed template with string data should never fail; if it
		// does, abort generation rather than emit a file missing this route.
		panic(fmt.Sprintf("protoc-gen-custom: render handler for %s: %v", funcName, err))
	}

	return buf.Bytes()
}

// getServiceConfig returns the common.E_ServiceConfig option attached to the
// service. It returns nil when the service has no options or the extension is
// not set. It is not called anywhere in this file at present.
//
//nolint
func getServiceConfig(service *protogen.Service) *common.ServiceConfig {
	opts := service.Desc.Options()
	if opts == nil || !proto.HasExtension(opts, common.E_ServiceConfig) {
		return nil
	}

	value := proto.GetExtension(opts, common.E_ServiceConfig)
	if value == nil {
		return nil
	}
	return value.(*common.ServiceConfig)
}

// getCommonMethodConfig returns the common.E_MethodConfig option attached to
// the RPC. It returns nil when the method has no options or the extension is
// not set. It is not called anywhere in this file at present.
//
//nolint
func getCommonMethodConfig(method *protogen.Method) *common.MethodConfig {
	opts := method.Desc.Options()
	if opts == nil || !proto.HasExtension(opts, common.E_MethodConfig) {
		return nil
	}
	return proto.GetExtension(opts, common.E_MethodConfig).(*common.MethodConfig)
}

// getMethodHTTPRule returns the HTTP rule for an RPC.
//
// When the method carries a google.api.http annotation, the rule is derived
// from it via parseHTTPRule. Otherwise it falls back to method "post",
// body "*", and the path produced by generateRoutePath from the service and
// method Go names.
func getMethodHTTPRule(method *protogen.Method) *HTTPRule {
	if opts := method.Desc.Options().(*descriptorpb.MethodOptions); opts != nil && proto.HasExtension(opts, annotations.E_Http) {
		if annotated := proto.GetExtension(opts, annotations.E_Http).(*annotations.HttpRule); annotated != nil {
			return parseHTTPRule(annotated)
		}
	}

	fallback := generateRoutePath(method.Parent.GoName, method.GoName)
	return &HTTPRule{
		Method: "post",
		Path:   fallback,
		Body:   "*",
	}
}

// parseHTTPRule converts a google.api.http annotation into an HTTPRule.
//
// Method is always lower case.
//   - The standard verbs (get/post/put/delete/patch) map directly to their path.
//   - A custom pattern uses its declared kind and path. Previously custom
//     patterns fell into the default branch, so every such RPC was mapped to
//     POST "/" and their routes collided. A missing kind or path still falls
//     back to "post" or "/".
//   - A rule without a pattern maps to POST "/".
//
// Body is copied from the annotation unchanged.
func parseHTTPRule(httpRule *annotations.HttpRule) *HTTPRule {
	rule := &HTTPRule{
		Body: httpRule.Body,
	}

	switch pattern := httpRule.Pattern.(type) {
	case *annotations.HttpRule_Get:
		rule.Method = "get"
		rule.Path = pattern.Get
	case *annotations.HttpRule_Post:
		rule.Method = "post"
		rule.Path = pattern.Post
	case *annotations.HttpRule_Put:
		rule.Method = "put"
		rule.Path = pattern.Put
	case *annotations.HttpRule_Delete:
		rule.Method = "delete"
		rule.Path = pattern.Delete
	case *annotations.HttpRule_Patch:
		rule.Method = "patch"
		rule.Path = pattern.Patch
	case *annotations.HttpRule_Custom:
		rule.Method = "post"
		rule.Path = "/"
		// GetKind/GetPath are nil-safe if Custom itself is unset.
		if kind := pattern.Custom.GetKind(); kind != "" {
			rule.Method = strings.ToLower(kind)
		}
		if path := pattern.Custom.GetPath(); path != "" {
			rule.Path = path
		}
	default:
		rule.Method = "post"
		rule.Path = "/"
	}

	return rule
}

// generateRoutePath builds the fallback route "/<service>/<method>" with both
// segments lower-cased, e.g. ("UserService", "GetInfo") -> "/userservice/getinfo".
func generateRoutePath(serviceName, methodName string) string {
	return "/" + strings.ToLower(serviceName) + "/" + strings.ToLower(methodName)
}
